#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Feb 17 11:44:57 2024

@author: jcfq2
"""

import os

jwst_dir='/Users/jcfq2/data/observations/jwst'

os.chdir(jwst_dir)

import matplotlib.pyplot as plt
from astropy.visualization import make_lupton_rgb
from astropy.table import Table
import ch4_fiddlesticks as ch4
import pandas as pd
import h3ppy
import glob
import spiceypy as spice
from astropy.io import fits
import numpy as np
from importlib import reload
import JWSTSolarSystemPointing as jssp
reload(jssp)
#import jwst_uranus as jwstu

# In[0]

# Hand the kernel directory to the pointing module so it can load its kernels.
kernel_dir = '/Users/jcfq2/data/observations/jwst/kernels/'
jssp.load_kernels(kdir=kernel_dir)

# A default h3ppy model object, kept around because it is handy later on.
h3p_model = h3ppy.h3p()

# In[1]

file_dir = '/Users/jcfq2/data/observations/jwst/5308_dither_combined/'
fn = 'jw05308001001_03119_g395h-f290lp_s3d.fits'
file = f'{file_dir}{fn}'

# Pointing helper object built from the dither-combined cube.
geo = jssp.JWSTSolarSystemPointing(file)

# Calculate the geometry for the full IFU cube (a three-dimensional array).
cube = geo.full_fov()

# Wavelength scale of this observation; per the original note it should be
# the same for every NIRSpec G395H observation with this filter.
wave = geo.get_wavelength()

# Tabulate the available geometric parameter names.  pandas is used purely to
# get a tidy table; the result is only shown when run interactively.
pd.DataFrame(geo.keys, columns=["Parameter"])


# %%

# NOTE(review): both aliases below are bound to the *same* function,
# get_pixel_polygons, so `radec` and `polyedge` are produced by identical
# calls.  The original trailing remarks (".get_pixel_polygons_radec()")
# suggest `gpprd` was meant to be an RA/Dec variant -- confirm which function
# is actually wanted before relying on ra_corners / dec_corners.
from JWSTSolarSystemPointing import (
    get_pixel_polygons as gpp,
    get_pixel_polygons as gpprd,
)

radec = gpprd(geo)
polyedge = gpp(geo)

# First two entries of each result, named after their intended use.
ra_corners, dec_corners = radec[0], radec[1]
latlong_corners, latlong_pixel_corners = polyedge[0], polyedge[1]




# In[9] this displays a single image but can't do three colour or coadd

files = sorted(
    glob.glob("/Users/jcfq2/data/observations/jwst/5308_dither_separated/*.fits")
)

# files=files[10]

# The file count is immediately overridden, so only the first exposure is shown.
lens = len(files)
lens = 1

for i in range(lens):
    geo = jssp.JWSTSolarSystemPointing(files[i])
    wave = geo.get_wavelength()
    cube = geo.full_fov()
    ra, dec = geo.get_delta_ra_dec_arcsec()
    spec = geo.convert(wave, geo.im[:, 25, 25])

    # Trim the first two and last two samples along the last image axis, and
    # trim the coordinate grids to match.
    sy, ey = 2, -2
    geoim = geo.im[:, :, sy:ey]
    ra, dec = ra[:, sy:ey], dec[:, sy:ey]

    # Median over every plane whose wavelength lies strictly inside 3.0-3.2.
    wavemin, wavemax = 3.0, 3.2
    whw = np.flatnonzero((wave > wavemin) & (wave < wavemax))
    im_reflect = np.nanmedian(geoim[whw], axis=0)

    im_methane = geoim[2775]

    # Narrow-window median; note that im_h3p is replaced by ab - cd just below.
    wavemin, wavemax = 3.9529, 3.9535
    whw = np.flatnonzero((wave > wavemin) & (wave < wavemax))
    im_h3p = np.nanmedian(geoim[whw], axis=0)

    # Sum of selected planes (ab) minus a weighted combination of other
    # planes (cd).  Presumably line planes minus background -- TODO confirm.
    ab = (
        geoim[2239] + geoim[1568] + geoim[1642]
        + geoim[1010] + geoim[1011] + geoim[1018] + geoim[1019]
    )
    cd = (
        (geoim[2241]*0.8*0.5 + geoim[2236]*1.19*0.5)
        + ((geoim[1571] + geoim[1565])*0.5)
        + (0.4*geoim[1640]/5.21)
        + (geoim[1004] + geoim[1005]*2)
    )
    im_h3p = ab - cd

    im_heat = geoim[3487]

    # NaNs become zero before each channel is scaled by its own maximum.
    image_g = np.nan_to_num(im_h3p)
    image_r = np.nan_to_num(im_heat)
    image_b = np.nan_to_num(im_methane)
    image = make_lupton_rgb(
        image_r/np.max(image_r),
        image_g/np.max(image_g)*3,
        image_b/np.max(image_b)*4,
    )

    # Only the green channel is drawn, compressed with a 0.1 power law.
    plt.pcolormesh(ra, dec, image_g**0.1)
    plt.xlim([-5, 5])
    plt.ylim([2, 12])
    plt.show()
# plt.imshow(image)

# Channel maxima of the last exposure processed.
print(np.max(image_r))
print(np.max(image_g))
print(np.max(image_b))




# In[9] plotting movie
from JWSTSolarSystemPointing import get_pixel_polygons as jssp_gpp

from pypolyclip import clip_multi, clip_single

# Output-grid cells per unit of sky offset.  With the *3600 factor applied to
# the pointing planes below this is presumably cells per arcsecond (assumes
# those planes are in degrees -- TODO confirm).
xscale = 100
yscale = 100

# Grid handed to clip_multi.  Anything falling outside it is skipped, so it
# only needs to be at least as large as the range of clipped coordinates.
naxis = (10*xscale, 10*yscale)

# Running area-weighted sums, indexed [y, x]; one spare row/column is kept as
# in the original allocation.
map_shape = (int(10*yscale+1), int(10*xscale+1))
naxis_mapcount = np.zeros(map_shape)
naxis_h3p = np.zeros(map_shape)
naxis_methane = np.zeros(map_shape)
naxis_heat = np.zeros(map_shape)
naxis_reflect = np.zeros(map_shape)

files = sorted(glob.glob(
    "/Users/jcfq2/data/observations/jwst/5308_dither_separated/*.fits"))

# files=files[10]

lens = len(files)
# lens=5
z0 = 0          # exposures accumulated into the current frame so far
zz = 0          # number of the frame most recently written
x_offset = 5    # shifts applied (in the same units as the scaled offsets)
y_offset = -3   # so that the target lands inside the output grid
sy = 2          # samples trimmed from the start of the last image axis
ey = -2         # samples trimmed from the end of the last image axis

savedir = "/Users/jcfq2/OneDrive - Northumbria University - Production Azure AD/python/jwst/saturn/movies/"


def _channel_images(geo, sy, ey):
    """Return the 2-D h3p / methane / heat / reflect images for one cube.

    Every image is taken from ``geo.im`` trimmed to ``[:, :, sy:ey]``.
    """
    wave = geo.get_wavelength()
    geoim = geo.im[:, :, sy:ey]

    # Median of every plane whose wavelength lies strictly inside 3.0-3.2.
    whw = np.flatnonzero((wave > 3.0) & (wave < 3.2))
    img_reflect = np.nanmedian(geoim[whw, :, :], axis=0)

    # Sum of selected planes (ab) minus a weighted combination of other
    # planes (cd).  Presumably line planes minus background -- TODO confirm.
    ab = (geoim[2239] + geoim[1568] + geoim[1642]
          + geoim[1010] + geoim[1011] + geoim[1018] + geoim[1019])
    cd = ((geoim[2241]*0.8*0.5 + geoim[2236]*1.19*0.5)
          + ((geoim[1571] + geoim[1565])*0.5)
          + (0.4*geoim[1640]/5.21)
          + (geoim[1004] + geoim[1005]*2))

    return {
        'h3p': ab - cd,
        'methane': geoim[2775, :, :],
        'heat': geoim[3487, :, :],
        'reflect': img_reflect,
    }


def _corner_grids(geo, sy, ey):
    """Return [(x, y), ...] grid-unit coordinates of pixel corners 1 to 4."""
    # The original script computed the centre geometry before the corners;
    # kept in case full_fov(corner=...) relies on it -- TODO confirm.
    geo.full_fov()
    grids = []
    for corner in (1, 2, 3, 4):
        fov = geo.full_fov(corner=corner)
        x = (fov[13, :, :] - geo.ra_target) * -3600.0 * xscale + (x_offset*xscale)
        y = (fov[14, :, :] - geo.dec_target) * 3600.0 * yscale + (y_offset*yscale)
        grids.append((x[:, sy:ey], y[:, sy:ey]))
    return grids


def _project_exposure(path, sy, ey):
    """Clip one exposure's pixels onto the output grid.

    Returns ``(yc, xc, area, weighted)`` where ``weighted`` maps each channel
    name to ``value * area`` per overlapped grid cell, or ``None`` when the
    exposure has no pixel with a non-NaN h3p value.
    """
    geo = jssp.JWSTSolarSystemPointing(path)
    channels = _channel_images(geo, sy, ey)
    grids = _corner_grids(geo, sy, ey)

    # Boolean masking walks the image in row-major order, i.e. the same order
    # as the nested per-pixel loop it replaces, without the quadratic
    # vstack/hstack growth.
    valid = ~np.isnan(channels['h3p'])
    if not valid.any():
        return None

    # One row per valid pixel, one column per corner.
    px = np.stack([x[valid] for x, _ in grids], axis=1)
    py = np.stack([y[valid] for _, y in grids], axis=1)

    # `slices` holds one slice per polygon, selecting that polygon's entries
    # in xc / yc / area.
    xc, yc, area, slices = clip_multi(px, py, naxis)

    # Weights are the raw overlap areas; no per-polygon normalisation is
    # applied.
    weighted = {}
    for name, img in channels.items():
        values = img[valid]
        contribution = np.zeros_like(area)
        for k, s in enumerate(slices):
            contribution[s] = values[k]*area[s]
        weighted[name] = contribution
    return yc, xc, area, weighted


def _accumulate(projection):
    """Add one exposure's projection (or nothing, if None) to the running maps."""
    if projection is None:
        return
    yc, xc, area, weighted = projection
    # Neighbouring polygons can overlap the same grid cell, so (yc, xc) may
    # contain repeated indices.  np.add.at sums every contribution, whereas
    # ``m[yc, xc] = m[yc, xc] + v`` keeps only the last one per cell.
    np.add.at(naxis_mapcount, (yc, xc), area)
    np.add.at(naxis_h3p, (yc, xc), weighted['h3p'])
    np.add.at(naxis_methane, (yc, xc), weighted['methane'])
    np.add.at(naxis_heat, (yc, xc), weighted['heat'])
    np.add.at(naxis_reflect, (yc, xc), weighted['reflect'])


def _reset_maps():
    """Zero every running map so the next frame starts from an empty grid."""
    for running in (naxis_mapcount, naxis_h3p, naxis_methane,
                    naxis_heat, naxis_reflect):
        running.fill(0.0)


def _compose_frame():
    """Return ``(rgb, red, green, blue)`` built from the running maps."""
    # Cells with no coverage are 0/0; nan_to_num turns them into zeros.
    with np.errstate(divide='ignore', invalid='ignore'):
        mean_h3p = np.nan_to_num(naxis_h3p/naxis_mapcount)
        mean_heat = np.nan_to_num(naxis_heat/naxis_mapcount)
        mean_reflect = np.nan_to_num(naxis_reflect/naxis_mapcount)
    # Negative values are clipped first: their square root would be NaN and
    # would then be fed into the RGB conversion.
    green = np.sqrt(np.clip(mean_h3p, 0, None))/135
    red = mean_heat/140000
    blue = mean_reflect/90000
    rgb = make_lupton_rgb(red*5, green*5, blue*10)
    return rgb, red, green, blue


def _save_frame(rgb, frame_number, fill_figure=False):
    """Draw one RGB frame, save it as a PDF in `savedir`, then show it."""
    plt.imshow(rgb, aspect=1, origin='lower')
    if fill_figure:
        # `position` is an Axes property, not an image property, so it is set
        # on the axes rather than passed to imshow.
        plt.gca().set_position([0, 0, 1, 1])
    plt.xlim((1*xscale, 8*xscale))
    plt.ylim((0.9*yscale, 5.5*yscale))
    plt.savefig(savedir+'saturn_movie'+str(frame_number)+'.pdf', dpi=300)
    plt.show()


# NOTE(review): the grouping is kept exactly as written -- the first frame
# holds four exposures and every later frame five (the exposure that triggers
# a frame seeds the next one).  Confirm this matches the intended grouping.
for i in range(lens):
    projection = _project_exposure(files[i], sy, ey)

    if z0 < 4:
        print(z0)
        _accumulate(projection)
        z0 = z0+1
    else:
        z0 = 0
        zz = zz+1
        image, image_r, image_g, image_b = _compose_frame()
        _save_frame(image, zz)

        # Start the next frame from this exposure alone.  The maps are zeroed
        # first so no sums from the finished frame survive outside the
        # current footprint.
        _reset_maps()
        _accumulate(projection)

# Final frame: whatever has been accumulated since the last one was written.
zz = zz+1
image, image_r, image_g, image_b = _compose_frame()
_save_frame(image, zz, fill_figure=True)

# bbb=np.nanmedian(aaa[0:2000],axis=1)
# plt.plot(bbb)
# plt.show()

# spectral_height_map[:,wavenum] = bbb
# plt.imshow(spectral_height_map,aspect=len(whw)/2000/2)
# plt.show()





# In[9] as above but with four movie panels
from JWSTSolarSystemPointing import get_pixel_polygons as jssp_gpp

from pypolyclip import clip_multi, clip_single

# Output-grid cells per unit of sky offset.  With the *3600 factor applied to
# the pointing planes below this is presumably cells per arcsecond (assumes
# those planes are in degrees -- TODO confirm).
xscale = 100
yscale = 100

# Grid handed to clip_multi.  Anything falling outside it is skipped, so it
# only needs to be at least as large as the range of clipped coordinates.
naxis = (10*xscale, 10*yscale)

# Running area-weighted sums, indexed [y, x]; one spare row/column is kept as
# in the original allocation.
map_shape = (int(10*yscale+1), int(10*xscale+1))
naxis_mapcount = np.zeros(map_shape)
naxis_h3p = np.zeros(map_shape)
naxis_methane = np.zeros(map_shape)
naxis_heat = np.zeros(map_shape)
naxis_reflect = np.zeros(map_shape)

files = sorted(glob.glob(
    "/Users/jcfq2/data/observations/jwst/5308_dither_separated/*.fits"))

# files=files[10]

lens = len(files)
# lens=5
z0 = 0          # exposures accumulated into the current frame so far
zz = 0          # number of the frame most recently written
x_offset = 5    # shifts applied (in the same units as the scaled offsets)
y_offset = -3   # so that the target lands inside the output grid
sy = 3          # samples trimmed from the start of the last image axis
ey = -2         # samples trimmed from the end of the last image axis

savedir = "/Users/jcfq2/OneDrive - Northumbria University - Production Azure AD/python/jwst/saturn/movies/"


def _channel_images(geo, sy, ey):
    """Return the 2-D h3p / methane / heat / reflect images for one cube.

    Every image is taken from ``geo.im`` trimmed to ``[:, :, sy:ey]``.
    """
    wave = geo.get_wavelength()
    geoim = geo.im[:, :, sy:ey]

    # Median of every plane whose wavelength lies strictly inside 3.0-3.2.
    whw = np.flatnonzero((wave > 3.0) & (wave < 3.2))
    img_reflect = np.nanmedian(geoim[whw, :, :], axis=0)

    # Sum of selected planes (ab) minus a weighted combination of other
    # planes (cd).  Presumably line planes minus background -- TODO confirm.
    ab = (geoim[2239] + geoim[1568] + geoim[1642]
          + geoim[1010] + geoim[1011] + geoim[1018] + geoim[1019])
    cd = ((geoim[2241]*0.8*0.5 + geoim[2236]*1.19*0.5)
          + ((geoim[1571] + geoim[1565])*0.5)
          + (0.4*geoim[1640]/5.21)
          + (geoim[1004] + geoim[1005]*2))

    return {
        'h3p': ab - cd,
        'methane': geoim[2775, :, :],
        'heat': geoim[3487, :, :],
        'reflect': img_reflect,
    }


def _corner_grids(geo, sy, ey):
    """Return [(x, y), ...] grid-unit coordinates of pixel corners 1 to 4."""
    # The original script computed the centre geometry before the corners;
    # kept in case full_fov(corner=...) relies on it -- TODO confirm.
    geo.full_fov()
    grids = []
    for corner in (1, 2, 3, 4):
        fov = geo.full_fov(corner=corner)
        x = (fov[13, :, :] - geo.ra_target) * -3600.0 * xscale + (x_offset*xscale)
        y = (fov[14, :, :] - geo.dec_target) * 3600.0 * yscale + (y_offset*yscale)
        grids.append((x[:, sy:ey], y[:, sy:ey]))
    return grids


def _project_exposure(path, sy, ey):
    """Clip one exposure's pixels onto the output grid.

    Returns ``(yc, xc, area, weighted)`` where ``weighted`` maps each channel
    name to ``value * area`` per overlapped grid cell, or ``None`` when the
    exposure has no pixel with a non-NaN h3p value.
    """
    geo = jssp.JWSTSolarSystemPointing(path)
    channels = _channel_images(geo, sy, ey)
    grids = _corner_grids(geo, sy, ey)

    # Boolean masking walks the image in row-major order, i.e. the same order
    # as the nested per-pixel loop it replaces, without the quadratic
    # vstack/hstack growth.
    valid = ~np.isnan(channels['h3p'])
    if not valid.any():
        return None

    # One row per valid pixel, one column per corner.
    px = np.stack([x[valid] for x, _ in grids], axis=1)
    py = np.stack([y[valid] for _, y in grids], axis=1)

    # `slices` holds one slice per polygon, selecting that polygon's entries
    # in xc / yc / area.
    xc, yc, area, slices = clip_multi(px, py, naxis)

    # Weights are the raw overlap areas; no per-polygon normalisation is
    # applied.
    weighted = {}
    for name, img in channels.items():
        values = img[valid]
        contribution = np.zeros_like(area)
        for k, s in enumerate(slices):
            contribution[s] = values[k]*area[s]
        weighted[name] = contribution
    return yc, xc, area, weighted


def _accumulate(projection):
    """Add one exposure's projection (or nothing, if None) to the running maps."""
    if projection is None:
        return
    yc, xc, area, weighted = projection
    # Neighbouring polygons can overlap the same grid cell, so (yc, xc) may
    # contain repeated indices.  np.add.at sums every contribution, whereas
    # ``m[yc, xc] = m[yc, xc] + v`` keeps only the last one per cell.
    np.add.at(naxis_mapcount, (yc, xc), area)
    np.add.at(naxis_h3p, (yc, xc), weighted['h3p'])
    np.add.at(naxis_methane, (yc, xc), weighted['methane'])
    np.add.at(naxis_heat, (yc, xc), weighted['heat'])
    np.add.at(naxis_reflect, (yc, xc), weighted['reflect'])


def _reset_maps():
    """Zero every running map so the next frame starts from an empty grid."""
    for running in (naxis_mapcount, naxis_h3p, naxis_methane,
                    naxis_heat, naxis_reflect):
        running.fill(0.0)


def _compose_frame():
    """Return ``(rgb, red, green, blue)`` built from the running maps."""
    # Cells with no coverage are 0/0; nan_to_num turns them into zeros.
    with np.errstate(divide='ignore', invalid='ignore'):
        mean_h3p = np.nan_to_num(naxis_h3p/naxis_mapcount)
        mean_heat = np.nan_to_num(naxis_heat/naxis_mapcount)
        mean_reflect = np.nan_to_num(naxis_reflect/naxis_mapcount)
    # Negative values are clipped first: their square root would be NaN and
    # would then be fed into the RGB conversion and the green panel.
    green = np.sqrt(np.clip(mean_h3p, 0, None))/135
    red = mean_heat/140000
    blue = mean_reflect/90000
    rgb = make_lupton_rgb(red*5, green*5, blue*10)
    return rgb, red, green, blue


def _save_four_panel_frame(rgb, red, green, blue, frame_number):
    """Draw the 2x2 frame, save it as PDF and PNG in `savedir`, then show it.

    Layout: RGB composite top-left, green top-right, blue bottom-left,
    red bottom-right.
    """
    fig = plt.figure()
    panels = (
        ([0, 0.5, 0.5, 0.5], rgb, {}),
        ([0.5, 0.5, 0.5, 0.5], green, {'vmin': 0}),
        ([0, 0, 0.5, 0.5], blue, {'vmin': 0}),
        ([0.5, 0.0, 0.5, 0.5], red, {'vmin': 0}),
    )
    for position, data, extra in panels:
        ax = plt.subplot(position=position)
        ax.set_axis_off()
        ax.imshow(data, aspect=1, origin='lower', **extra)
        ax.set_xlim((1*xscale, 8*xscale))
        ax.set_ylim((0.9*yscale, 5.5*yscale))

    fig.set_facecolor("black")
    plt.savefig(savedir+'saturn_movie'+str(frame_number)+'.pdf', dpi=300)
    plt.savefig(savedir+'png_saturn_movie'+str(frame_number)+'.png', dpi=300)
    plt.show()


# NOTE(review): the grouping is kept exactly as written -- the first frame
# holds four exposures and every later frame five (the exposure that triggers
# a frame seeds the next one).  Confirm this matches the intended grouping.
for i in range(lens):
    projection = _project_exposure(files[i], sy, ey)

    if z0 < 4:
        print(z0)
        _accumulate(projection)
        z0 = z0+1
    else:
        z0 = 0
        zz = zz+1
        image, image_r, image_g, image_b = _compose_frame()
        _save_four_panel_frame(image, image_r, image_g, image_b, zz)

        # Start the next frame from this exposure alone.  The maps are zeroed
        # first so no sums from the finished frame survive outside the
        # current footprint.
        _reset_maps()
        _accumulate(projection)

# Final frame: whatever has been accumulated since the last one was written.
zz = zz+1
image, image_r, image_g, image_b = _compose_frame()
_save_four_panel_frame(image, image_r, image_g, image_b, zz)
